Skip to content

feat: qwen3.5 perf opt#1351

Open
blueswhen wants to merge 1 commit into
mainfrom
opt_qwen35
Open

feat: qwen3.5 perf opt#1351
blueswhen wants to merge 1 commit into
mainfrom
opt_qwen35

Conversation

@blueswhen

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces several performance optimizations and features, including prioritizing the fa3 attention backend, implementing fast CUDA graph planning for FlashInfer, adding a fused add_rmsnorm Triton kernel, and optimizing the Qwen3Next model's GDN decode path and sampling backend. The code review feedback highlights critical correctness and performance improvements: reversing in-place logit division if FlashInfer sampling fails, skipping temperature scaling for greedy sampling when logprobs are not needed, removing redundant loops and memory operations in the fused RMSNorm kernel, avoiding redundant concurrent writes in the MoE alignment kernel, asserting tensor contiguity in GDN decode kernels, and replacing manual sigmoid implementations with Triton's built-in tl.sigmoid.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +174 to +183
if temperature != 1.0:
logits.div_(temperature)

if top_k == vocab_size and top_p != 1.0:
top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device)
return _flashinfer_top_p_sample_from_logits(logits, top_p_tensor)

top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device)
top_k_tensor = _get_uniform_tensor(top_k, logits.shape[0], torch.int32, logits.device)
return _flashinfer_top_p_top_k_sample_from_logits(logits, top_p_tensor, top_k_tensor)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

In _try_flashinfer_sample_without_penalty, logits.div_(temperature) is performed in-place. If the subsequent flashinfer sampling function returns None (e.g., because flashinfer is not available or fails), the function returns None, and the caller falls back to standard sampling. However, logits has already been modified in-place, so the caller will divide logits by the temperature again, leading to incorrect sampling probabilities. We should restore logits by multiplying back by temperature if the flashinfer call returns None.

Suggested change
if temperature != 1.0:
logits.div_(temperature)
if top_k == vocab_size and top_p != 1.0:
top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device)
return _flashinfer_top_p_sample_from_logits(logits, top_p_tensor)
top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device)
top_k_tensor = _get_uniform_tensor(top_k, logits.shape[0], torch.int32, logits.device)
return _flashinfer_top_p_top_k_sample_from_logits(logits, top_p_tensor, top_k_tensor)
if top_k == vocab_size and top_p != 1.0:
top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device)
if temperature != 1.0:
logits.div_(temperature)
res = _flashinfer_top_p_sample_from_logits(logits, top_p_tensor)
if res is None and temperature != 1.0:
logits.mul_(temperature)
return res
top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device)
top_k_tensor = _get_uniform_tensor(top_k, logits.shape[0], torch.int32, logits.device)
if temperature != 1.0:
logits.div_(temperature)
res = _flashinfer_top_p_top_k_sample_from_logits(logits, top_p_tensor, top_k_tensor)
if res is None and temperature != 1.0:
logits.mul_(temperature)
return res

Comment on lines +52 to +95
@triton.jit
def _add_rms_norm_fwd_fused(
X,
R,
Y,
W,
x_stride0,
x_stride1,
r_stride0,
r_stride1,
y_stride0,
y_stride1,
N,
eps,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
row = tl.program_id(0)
X += row * x_stride0
R += row * r_stride0
Y += row * y_stride0

_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32)
r = tl.load(R + cols * r_stride1, mask=mask, other=0.0).to(tl.float32)
x = x + r
tl.store(X + cols * x_stride1, x.to(X.dtype.element_ty), mask=mask)
_var += x * x

var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32)
y = x * rstd
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
y *= w
tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _add_rms_norm_fwd_fused kernel uses a loop over N with BLOCK_SIZE steps. However, N > BLOCK_SIZE is explicitly forbidden and raises a RuntimeError in the Python wrapper. Since N <= BLOCK_SIZE is guaranteed, the loop is completely redundant. More importantly, storing X and then reloading it in the second loop causes unnecessary global memory read/write operations, which is extremely slow for memory-bandwidth-bound kernels like RMSNorm. We can load X only once and avoid the loop entirely.

@triton.jit
def _add_rms_norm_fwd_fused(
    X,
    R,
    Y,
    W,
    x_stride0,
    x_stride1,
    r_stride0,
    r_stride1,
    y_stride0,
    y_stride1,
    N,
    eps,
    HAS_WEIGHT: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    row = tl.program_id(0)
    X += row * x_stride0
    R += row * r_stride0
    Y += row * y_stride0

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32)
    r = tl.load(R + cols * r_stride1, mask=mask, other=0.0).to(tl.float32)
    x = x + r
    tl.store(X + cols * x_stride1, x.to(X.dtype.element_ty), mask=mask)

    var = tl.sum(x * x, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    y = x * rstd
    if HAS_WEIGHT:
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        y *= w
    tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)

Comment on lines +95 to +105
if b_temperatures is not None:
logits.div_(b_temperatures.view((-1, 1)))

if is_all_greedy and not need_logprobs:
batch_next_token_ids = torch.argmax(logits, -1)
if get_env_start_args().mtp_mode:
batch_next_token_logprobs = torch.zeros(
batch_next_token_ids.shape, dtype=torch.float32, device=batch_next_token_ids.device
)
return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1)
return batch_next_token_ids.view(-1), None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The in-place division logits.div_(b_temperatures.view((-1, 1))) is performed before checking if the request is greedy and doesn't need logprobs. For greedy sampling, temperature scaling does not affect the argmax result. Performing this division on a large logits tensor (e.g., vocab size 150k+) is a significant waste of GPU compute and memory bandwidth. Reordering the greedy check to return early before performing the division will improve performance.

Suggested change
if b_temperatures is not None:
logits.div_(b_temperatures.view((-1, 1)))
if is_all_greedy and not need_logprobs:
batch_next_token_ids = torch.argmax(logits, -1)
if get_env_start_args().mtp_mode:
batch_next_token_logprobs = torch.zeros(
batch_next_token_ids.shape, dtype=torch.float32, device=batch_next_token_ids.device
)
return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1)
return batch_next_token_ids.view(-1), None
if is_all_greedy and not need_logprobs:
batch_next_token_ids = torch.argmax(logits, -1)
if get_env_start_args().mtp_mode:
batch_next_token_logprobs = torch.zeros(
batch_next_token_ids.shape, dtype=torch.float32, device=batch_next_token_ids.device
)
return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1)
return batch_next_token_ids.view(-1), None
if b_temperatures is not None:
logits.div_(b_temperatures.view((-1, 1)))

Comment on lines +231 to +233
if ZERO_EXPERT_TOKEN_NUM:
expert_offs = tl.arange(0, BLOCK_EXPERT)
tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In moe_align_fused_kernel, if ZERO_EXPERT_TOKEN_NUM is True, every program block in the grid will concurrently write 0 to expert_token_num_ptr. While functionally correct, this is redundant and causes unnecessary memory write contention. Adding and token_block == 0 to the condition ensures only the first block performs the zeroing.

Suggested change
if ZERO_EXPERT_TOKEN_NUM:
expert_offs = tl.arange(0, BLOCK_EXPERT)
tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num)
if ZERO_EXPERT_TOKEN_NUM and token_block == 0:
expert_offs = tl.arange(0, BLOCK_EXPERT)
tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num)

Comment on lines +69 to +78
def pack_gdn_decode_inputs(
mixed_qkv: torch.Tensor,
z_raw: torch.Tensor,
a_raw: torch.Tensor,
b_raw: torch.Tensor,
num_k_heads: int,
head_k_dim: int,
num_v_heads: int,
head_v_dim: int,
):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In pack_gdn_decode_inputs, z_raw is loaded in the Triton kernel using flattened indexing (z_raw + row * stride_z_b + qkv_offsets), which assumes that the last two dimensions of z_raw are contiguous. However, there is no assertion in the Python wrapper to ensure z_raw is contiguous. If a non-contiguous tensor is passed, it will silently produce incorrect results. We should add an assertion to prevent this.

Suggested change
def pack_gdn_decode_inputs(
mixed_qkv: torch.Tensor,
z_raw: torch.Tensor,
a_raw: torch.Tensor,
b_raw: torch.Tensor,
num_k_heads: int,
head_k_dim: int,
num_v_heads: int,
head_v_dim: int,
):
@torch.no_grad()
def pack_gdn_decode_inputs(
mixed_qkv: torch.Tensor,
z_raw: torch.Tensor,
a_raw: torch.Tensor,
b_raw: torch.Tensor,
num_k_heads: int,
head_k_dim: int,
num_v_heads: int,
head_v_dim: int,
):
assert z_raw.is_contiguous(), "z_raw must be contiguous"

Comment on lines +214 to +228
def conv_pack_gdn_decode_inputs(
mixed_qkv: torch.Tensor,
z_raw: torch.Tensor,
a_raw: torch.Tensor,
b_raw: torch.Tensor,
conv_state: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
conv_state_indices: torch.Tensor,
activation: str,
num_k_heads: int,
head_k_dim: int,
num_v_heads: int,
head_v_dim: int,
):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In conv_pack_gdn_decode_inputs, z_raw is loaded in the Triton kernel using flattened indexing, which assumes that the last two dimensions of z_raw are contiguous. However, there is no assertion in the Python wrapper to ensure z_raw is contiguous. If a non-contiguous tensor is passed, it will silently produce incorrect results. We should add an assertion to prevent this.

Suggested change
def conv_pack_gdn_decode_inputs(
mixed_qkv: torch.Tensor,
z_raw: torch.Tensor,
a_raw: torch.Tensor,
b_raw: torch.Tensor,
conv_state: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
conv_state_indices: torch.Tensor,
activation: str,
num_k_heads: int,
head_k_dim: int,
num_v_heads: int,
head_v_dim: int,
):
@torch.no_grad()
def conv_pack_gdn_decode_inputs(
mixed_qkv: torch.Tensor,
z_raw: torch.Tensor,
a_raw: torch.Tensor,
b_raw: torch.Tensor,
conv_state: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
conv_state_indices: torch.Tensor,
activation: str,
num_k_heads: int,
head_k_dim: int,
num_v_heads: int,
head_v_dim: int,
):
assert z_raw.is_contiguous(), "z_raw must be contiguous"

mask=offs_dim < dim_end,
other=0.0,
).to(tl.float32)
gate = 1.0 / (1.0 + tl.exp(-gate))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The manual implementation of sigmoid 1.0 / (1.0 + tl.exp(-gate)) is used. Triton has a built-in tl.sigmoid function which is cleaner, more readable, and can be better optimized by the compiler.

Suggested change
gate = 1.0 / (1.0 + tl.exp(-gate))
gate = tl.sigmoid(gate)

else:
gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32)
hidden_vals = tl.load(hidden_ptrs, mask=mask, other=0.0).to(tl.float32)
gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The manual implementation of sigmoid 1.0 / (1.0 + tl.exp(-gate_vals)) is used. Triton has a built-in tl.sigmoid function which is cleaner, more readable, and can be better optimized by the compiler.

Suggested change
gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals))
gate_vals = tl.sigmoid(gate_vals)

gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32)
else:
gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32)
gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The manual implementation of sigmoid 1.0 / (1.0 + tl.exp(-gate_vals)) is used. Triton has a built-in tl.sigmoid function which is cleaner, more readable, and can be better optimized by the compiler.

Suggested change
gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals))
gate_vals = tl.sigmoid(gate_vals)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant